import pickle
import csv
import json
import jsonlines
import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import sys

# Seed the module-level random generator with a fixed value; generate()
# draws from it through random.sample when picking sentences per token.
random.seed(42)

def get_doc(docs):
    """Return the cached result of ``nlp(docs)``, computing it on first use.

    Results are memoised in the module-level ``docmap`` mapping, keyed by
    ``docs`` itself.

    NOTE(review): neither ``docmap`` nor ``nlp`` is defined or imported in
    the visible part of this file, so calling this function as-is raises
    ``NameError`` unless they are provided elsewhere.
    """
    if docs not in docmap:
        docmap[docs] = nlp(docs)
    return docmap[docs]

def save_maps(fname,mmap):
    """Serialise ``mmap`` to the file ``fname`` with pickle.

    The file is created or truncated, and the highest pickle protocol
    supported by the running interpreter is used.
    """
    with open(fname, 'wb+') as out_file:
        pickler = pickle.Pickler(out_file, protocol=pickle.HIGHEST_PROTOCOL)
        pickler.dump(mmap)
        
def load_map(fname):
    """Load and return the object pickled in the file ``fname``.

    The file is opened in binary mode and closed again before returning.

    NOTE: unpickling can execute arbitrary code, so this must only be used
    on files from a trusted source.
    """
    with open(fname, 'rb') as handle:
        return pickle.load(handle)

# Tokens that are ignored when building the token -> key reverse index;
# generate() passes this list to create_norm_reverseindex().
stop_words_more = ["a","an","the","how","who","what","which","where","when","is","was","that","there"]
def create_norm_reverseindex(mmap,stop_words_more):
    """Build a reverse index from lowercase tokens to the keys of ``mmap``.

    ``mmap`` maps a key to an iterable of strings.  Each string is
    lowercased and split on single spaces; every distinct resulting token
    that is not in ``stop_words_more`` gets the key appended to its list.

    A key appears once per string that contains the token, so the same key
    can occur several times in one token's list.  Splitting on ``" "`` can
    also yield the empty string as a token (e.g. for doubled spaces).

    Returns a dict mapping token -> list of keys.
    """
    ignored = set(stop_words_more)
    index = {}
    for key, phrases in tqdm(mmap.items(), ascii=True):
        for phrase in phrases:
            # str.split(" ") always returns at least one piece, so every
            # phrase contributes its (non-stop-word) tokens.
            pieces = set(phrase.lower().split(" "))
            for token in pieces - ignored:
                index.setdefault(token, []).append(key)
    return index

def dedupe_map(mmap):
    """Return a copy of ``mmap`` whose values are lists without duplicates.

    Each value is passed through ``set``, so its elements must be hashable
    and their original order is not preserved.
    """
    return {
        key: list(set(values))
        for key, values in tqdm(mmap.items(), ascii=True)
    }

def generate(omcs_npmap,ans_ents,q_ents,fname,max_sents=30):
    """Write context/question/answer rows for keys that share a token.

    Steps:

    1. Collect the distinct keys and the distinct value strings of
       ``omcs_npmap``, print both counts, and dump them to
       ``all_sci_sents.json`` and ``all_sci_kws.json`` in the current
       working directory (these two files are overwritten on every call).
    2. Build the token -> keys reverse index with
       ``create_norm_reverseindex`` and keep only the tokens that also
       occur in ``q_ents`` or ``ans_ents``.  Index tokens are lowercase and
       contain no spaces, and the comparison is an exact match, so only
       entities of that form can be selected.
    3. For each selected token, sample at most ``max_sents`` of its keys
       and write one JSON line per ordered pair of unequal keys to
       ``fname``.

    Args:
        omcs_npmap: mapping of key -> iterable of strings.
        ans_ents: iterable of answer entity strings.
        q_ents: iterable of question entity strings.
        fname: path of the JSON-lines file to write.
        max_sents: upper bound on the keys sampled per token (default 30).
    """
    # Dict keys are already unique; keeping dict / first-seen order (rather
    # than going through a set) makes the dumped files stable across runs.
    omcs_sents = list(omcs_npmap)
    omcs_kws = list(dict.fromkeys(
        kw for kws in omcs_npmap.values() for kw in kws
    ))
    print(f"{len(omcs_sents)} {len(omcs_kws)}")

    # Context managers make sure both dumps are flushed and closed.
    with open("all_sci_sents.json","w") as sents_fd:
        json.dump(omcs_sents,sents_fd)
    with open("all_sci_kws.json","w") as kws_fd:
        json.dump(omcs_kws,kws_fd)

    omcs_revmap = create_norm_reverseindex(omcs_npmap,stop_words_more)
    allowed = set(q_ents).union(ans_ents).intersection(omcs_revmap)
    print(f"{len(allowed)}")

    with jsonlines.open(fname,"w") as ofd:
        # Set iteration order for strings changes between interpreter runs
        # (hash randomisation).  Sorting fixes the order in which the seeded
        # RNG is consumed, so the output is reproducible.
        for ans in tqdm(sorted(allowed),ascii=True):
            candidates = omcs_revmap[ans]
            sents = random.sample(candidates,min(len(candidates),max_sents))
            for sent1 in sents:
                for sent2 in sents:
                    # The reverse index may list a key more than once, so
                    # compare by value rather than by position.
                    if sent1==sent2:
                        continue
                    ofd.write({
                        "context":sent1,
                        "question":sent2,
                        "answer":ans,
                    })

def generate_omcs():
    """Generate the ``omcs_k_np`` / ``omcs_k_vp`` JSON-lines files.

    Loads the pickled ``omcsnp_map`` and ``omcsvp_map`` maps, reads the
    ``comm_aents.csv`` / ``comm_qents.csv`` entity lists (one entity per
    line, surrounding whitespace stripped) and calls ``generate`` once per
    map.  All paths are relative to the current working directory.
    """
    def _read_entities(path):
        # Read one stripped entity per line and close the file afterwards.
        with open(path) as fd:
            return [line.strip() for line in fd]

    omcs_npmap = load_map("../data/triples/omcsnp_map.pickled")
    omcs_vpmap = load_map("../data/triples/omcsvp_map.pickled")
    ans_ents = _read_entities("../scripts/comm_aents.csv")
    q_ents = _read_entities("../scripts/comm_qents.csv")

    generate(omcs_npmap,ans_ents,q_ents,"../data/raw_knw/omcs_k_np.jsonl")
    generate(omcs_vpmap,ans_ents,q_ents,"../data/raw_knw/omcs_k_vp.jsonl")
    
def generate_obqa():
    """Generate the ``obqa_k_*`` and ``obqa_qk_*`` JSON-lines files.

    The answer and question entity lists are the de-duplicated unions of
    the obqa, arc and qasc entity files (one entity per line, surrounding
    whitespace stripped, blank entries dropped).  ``generate`` is called
    with the ``opnp_map`` / ``opvp_map`` maps for the ``obqa_k_*`` outputs
    and with the ``qasc_np`` / ``qasc_vp`` maps for the ``obqa_qk_*``
    outputs.  All paths are relative to the current working directory.
    """
    def _read_entities(path):
        # Read one stripped entity per line and close the file afterwards.
        with open(path) as fd:
            return [line.strip() for line in fd]

    omcs_npmap = load_map("../data/triples/opnp_map.pickled")
    omcs_vpmap = load_map("../data/triples/opvp_map.pickled")
    ob_ans_ents = _read_entities("../scripts/obqa_aents.csv")
    ob_q_ents = _read_entities("../scripts/obqa_qents.csv")
    arc_ans_ents = _read_entities("../scripts/arc_aents.csv")
    arc_q_ents = _read_entities("../scripts/arc_qents.csv")
    qasc_ans_ents = _read_entities("../scripts/qasc_aents_uniq.csv")
    qasc_q_ents = _read_entities("../scripts/qasc_qents_uniq.csv")

    ans_set = set(ob_ans_ents+arc_ans_ents+qasc_ans_ents)
    q_set = set(qasc_q_ents+ob_q_ents+arc_q_ents)
    # discard() drops the blank entry when there is one and is a no-op
    # otherwise; list.remove("") would raise ValueError if no input file
    # happened to contain an empty line.
    ans_set.discard("")
    q_set.discard("")
    ans_ents = sorted(ans_set)
    q_ents = sorted(q_set)

    generate(omcs_npmap,ans_ents,q_ents,"../data/raw_knw/obqa_k_np.jsonl")
    generate(omcs_vpmap,ans_ents,q_ents,"../data/raw_knw/obqa_k_vp.jsonl")
    omcs_npmap = load_map("../data/triples/qasc_np.pickled")
    omcs_vpmap = load_map("../data/triples/qasc_vp.pickled")
    generate(omcs_npmap,ans_ents,q_ents,"../data/raw_knw/obqa_qk_np.jsonl")
    generate(omcs_vpmap,ans_ents,q_ents,"../data/raw_knw/obqa_qk_vp.jsonl")
    
    
def generate_arc():
    """Generate the ``arc_k_np`` / ``arc_k_vp`` JSON-lines files.

    Reads the ``arc_aents.csv`` / ``arc_qents.csv`` entity lists (one
    entity per line, surrounding whitespace stripped) and calls
    ``generate`` once per map.  All paths are relative to the current
    working directory.

    NOTE(review): the maps loaded here are ``qasc_np`` / ``qasc_vp``, not
    arc-specific ones -- presumably intentional, confirm.
    """
    def _read_entities(path):
        # Read one stripped entity per line and close the file afterwards.
        with open(path) as fd:
            return [line.strip() for line in fd]

    omcs_npmap = load_map("../data/triples/qasc_np.pickled")
    omcs_vpmap = load_map("../data/triples/qasc_vp.pickled")
    ans_ents = _read_entities("../scripts/arc_aents.csv")
    q_ents = _read_entities("../scripts/arc_qents.csv")

    generate(omcs_npmap,ans_ents,q_ents,"../data/raw_knw/arc_k_np.jsonl")
    generate(omcs_vpmap,ans_ents,q_ents,"../data/raw_knw/arc_k_vp.jsonl")
    
def generate_qasc():
    """Generate the ``qasc_k_np`` / ``qasc_k_vp`` JSON-lines files.

    Loads the pickled ``qasc_np`` / ``qasc_vp`` maps, reads the
    ``qasc_aents_uniq.csv`` / ``qasc_qents_uniq.csv`` entity lists (one
    entity per line, surrounding whitespace stripped) and calls
    ``generate`` once per map.  All paths are relative to the current
    working directory.
    """
    def _read_entities(path):
        # Read one stripped entity per line and close the file afterwards.
        with open(path) as fd:
            return [line.strip() for line in fd]

    omcs_npmap = load_map("../data/triples/qasc_np.pickled")
    omcs_vpmap = load_map("../data/triples/qasc_vp.pickled")
    ans_ents = _read_entities("../scripts/qasc_aents_uniq.csv")
    q_ents = _read_entities("../scripts/qasc_qents_uniq.csv")

    # Write to qasc_k_* files: the arc_k_* names belong to generate_arc(),
    # and reusing them here would overwrite that function's output.
    generate(omcs_npmap,ans_ents,q_ents,"../data/raw_knw/qasc_k_np.jsonl")
    generate(omcs_vpmap,ans_ents,q_ents,"../data/raw_knw/qasc_k_vp.jsonl")
    
# Maps the dataset name given on the command line to its generator.
fmap = {"arc":generate_arc,"qasc":generate_qasc,"obqa":generate_obqa,"omcs":generate_omcs}

# Only dispatch when executed as a script, so that importing this module
# (e.g. to reuse load_map) does not start a generation run.
if __name__ == "__main__":
    if len(sys.argv) < 2 or sys.argv[1] not in fmap:
        choices = "|".join(fmap)
        # Exit with a readable message instead of an IndexError/KeyError.
        sys.exit(f"usage: {sys.argv[0]} <{choices}>")
    inp = sys.argv[1]
    fmap[inp]()